# -*- coding: utf-8 -*-
"""
Created on Sun Apr  2 18:12:40 2023

@author: 29672
"""
import os
import torch
import pytorch_lightning as pl
from transformers import BertTokenizer, BertModel
from models.QAModel import QAData, QAModel
from utils.DigitEmbedding import DigitEmbedding

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Inference hyper-parameters (must match the values used at training time).
max_seq_length = 500
batch_size = 5
num_digits = 6

# Build the QA model with the same configuration the checkpoint was trained with.
model = QAModel(max_seq_length=max_seq_length,
                input_max_num_sequences=batch_size,
                num_digits=num_digits)
checkpoint_path = './run/best_model.pt'

if os.path.exists(checkpoint_path):
    print('load checkpoint')
    # map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
    # CPU-only machine; without it torch.load raises when CUDA is unavailable.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()

# Digit encoder used to transform token ids into per-digit representations.
digit_embedding = DigitEmbedding(num_digits, max_seq_length)

# Interactive loop: read a question + context from the console, run the model,
# and print the decoded answer span. Runs until the process is interrupted.
while True:
    # Read user input (prompts are user-facing and intentionally in Chinese).
    question = input("请输入问题：")
    context = input("请输入上下文：")

    # Tokenize the (question, context) pair; truncate to the model's max length.
    inputs = tokenizer(question, context, return_tensors='pt', padding=True,
                       truncation=True, max_length=max_seq_length)

    # Digit-encode the token ids for the model's digit-embedding input.
    input_ids = digit_embedding.encode(inputs['input_ids'][0].tolist())

    # Add a batch dimension: shape becomes (1, seq_length, ...). Note: the
    # batch size here is always 1 (one interactive example), not `batch_size`.
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    # All positions are real tokens (single example, no padding), so the
    # attention mask is all ones.
    attention_mask = torch.ones((1, input_ids.shape[1]), dtype=torch.long)

    # Predict start/end logits for the answer span (no gradients needed).
    with torch.no_grad():
        start_logits, end_logits = model(input_ids, attention_mask)

    # Pick the most likely start/end positions.
    start_index = torch.argmax(start_logits).item()
    end_index = torch.argmax(end_logits).item()
    # Guard against a predicted end before the start: the original slice
    # silently produced an empty answer in that case. Clamp so we always
    # return at least the start token.
    if end_index < start_index:
        end_index = start_index
    answer_ids = input_ids[0][start_index:end_index + 1].tolist()
    # Undo the digit encoding, then map token ids back to readable text.
    decoded = digit_embedding.decode(answer_ids)
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(decoded))

    print("回答：", answer)
